import torch
import numpy as np
from tqdm import tqdm

from ppl_dataset import get_loaders

def ppl_metric(model, tokenizer, datasets, seq_len=128, batch_size=4, device="cuda"):
    """Evaluate the perplexity of ``model`` on each named dataset.

    For every entry in ``datasets``, ``get_loaders`` builds the loaders and the
    second one (the test loader) is scored with ``llama_eval``. The results
    dict accumulated so far is printed after each dataset.

    Args:
        model: Model forwarded to ``llama_eval``.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        datasets: Iterable of dataset names understood by ``get_loaders``.
        seq_len: Sequence length forwarded to ``get_loaders``.
        batch_size: Batch size forwarded to ``get_loaders``.
        device: Device forwarded to ``llama_eval``.

    Returns:
        Dict mapping each dataset name to the value ``llama_eval`` returned.
    """
    results = {}
    for name in datasets:
        # get_loaders yields (train, test); only the test loader is evaluated.
        _train_loader, test_loader = get_loaders(
            name, tokenizer, seq_len=seq_len, batch_size=batch_size
        )
        results[name] = llama_eval(model, test_loader, device)
        print(results)
    return results

@torch.no_grad()
def llama_eval(model, test_lodaer, device):
    """Return the perplexity of ``model`` over every batch of ``test_lodaer``.

    Each batch is scored by next-token prediction: the logits at position
    ``t`` are compared with the token at position ``t + 1``. Perplexity is
    ``exp(total NLL / total predicted tokens)`` over the whole loader. This
    equals exp of the mean per-token loss across all batches.

    Args:
        model: Called as ``model(batch)``. The result must expose ``.logits``,
            indexed as (batch, position, vocab).
        test_lodaer: Iterable of 2-D token-id tensors (batch, position). The
            misspelled name is kept so existing keyword callers keep working.
        device: Target passed to ``batch.to(...)``.

    Returns:
        Perplexity as a Python ``float``. Returns ``nan`` when the loader yields
        no predictable tokens (an empty loader, or only length-1 sequences).
    """
    # Summing losses (instead of keeping every per-token loss until the end)
    # keeps memory constant regardless of how many batches the loader yields.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
    total_nll = 0.0
    total_tokens = 0
    for batch in tqdm(test_lodaer):
        batch = batch.to(device)
        lm_logits = model(batch).logits

        # Drop the last position's logits (nothing left to predict) and the
        # first label (nothing predicts it). Upcast to float32 so that
        # half-precision models do not lose precision in the loss.
        shift_logits = lm_logits[:, :-1, :].float()
        shift_labels = batch[:, 1:]

        total_nll += loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        ).item()
        total_tokens += shift_labels.numel()

    if total_tokens == 0:
        # Nothing was scored; avoid dividing by zero.
        return float("nan")
    return float(np.exp(total_nll / total_tokens))


import time

def measure_generation_latency(model, tokenizer, prompt, gen_length=128, batch_size=1, device='cuda'):
    """Time greedy generation for a batch of ``batch_size`` copies of ``prompt``.

    A short warm-up ``generate`` call (4 new tokens) runs first and is not
    timed. The timed call then requests ``gen_length`` new tokens.

    Args:
        model: Object exposing an HF-style ``generate(input_ids=...,
            attention_mask=..., max_new_tokens=..., do_sample=...,
            use_cache=...)`` method.
        tokenizer: Callable returning an object with ``input_ids`` when called
            with ``return_tensors="pt"``.
        prompt: Prompt text, replicated ``batch_size`` times.
        gen_length: ``max_new_tokens`` for the timed call.
        batch_size: Number of identical sequences generated together.
        device: Device the inputs are moved to. CUDA is synchronized around the
            timed region only when this is a CUDA device.

    Returns:
        ``(latency, throughput)``. ``latency`` is the wall-clock seconds of the
        timed ``generate`` call. ``throughput`` is ``batch_size * gen_length /
        latency``, or ``nan`` if latency is zero. The token count is nominal:
        it over-counts if generation stops early (e.g. on EOS).
    """
    # Every row is the same prompt, so tokenize once and replicate. No padding
    # is needed, which also keeps tokenizers without a pad token working.
    single_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
    input_ids = single_ids.repeat(batch_size, 1).to(device)
    attention_mask = torch.ones_like(input_ids)  # no padding: attend everywhere
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=False,
        use_cache=True,
    )

    dev = torch.device(device)

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the timer measures
        # real work. There is nothing to wait for on other devices.
        if dev.type == "cuda":
            torch.cuda.synchronize(dev)

    with torch.no_grad():
        # Warm-up, so one-time setup costs are not part of the measurement.
        model.generate(max_new_tokens=4, **gen_kwargs)

        _sync()
        start = time.perf_counter()  # monotonic, high-resolution interval clock
        model.generate(max_new_tokens=gen_length, **gen_kwargs)
        _sync()
        end = time.perf_counter()

    latency = end - start
    total_generated_tokens = batch_size * gen_length
    throughput = total_generated_tokens / latency if latency > 0 else float("nan")

    print(f"[Latency]    Batch size: {batch_size}, Generation length: {gen_length} → {latency:.3f} seconds")
    print(f"[Throughput] {throughput:.2f} tokens/sec")

    return latency, throughput


